Compare simple difference in functional effects across two conditions¶

Import Python modules. We use polyclonal for the plotting:

In [1]:
import copy
import itertools
import math

import altair as alt

import dms_variants.utils

import pandas as pd

import polyclonal
import polyclonal.plot

This notebook is parameterized by papermill. The next cell is tagged `parameters` and only defines placeholder values; the actual values passed for a run are assigned in the cell that follows it.

In [2]:
# this cell is tagged parameters for `papermill` parameterization
# Every value below is a `None` placeholder; the cell that follows assigns the
# actual values used for a run.
site_numbering_map_csv = None  # CSV path read into `site_numbering_map`
mutation_annotations_csv = None  # optional CSV path; annotations are skipped when falsy
diffs_csv = None  # output CSV path for the table of differences
chart_html = None  # output HTML path for the interactive lineplot / heatmap
corr_chart_html = None  # output HTML path for the condition-vs-condition scatter plot
params = None  # dict with the two conditions, the averaging method, and plot options
In [3]:
# Parameters
# Concrete values for this run; they replace the `None` placeholders assigned in
# the cell tagged `parameters`.
params = {
    # Each condition has a name (used to label its effect column) and the list
    # of selections whose functional effects are averaged for that condition.
    "condition_1": {
        "name": 220210,
        "selections": ["LibA-220210-293T_ACE2-1", "LibA-220210-293T_ACE2-2"],
    },
    "condition_2": {
        "name": 220302,
        "selections": ["LibA-220302-293T_ACE2-1", "LibA-220302-293T_ACE2-2"],
    },
    # how per-selection values are averaged; the notebook asserts "mean" or "median"
    "avg_method": "median",
    # when true, the per-selection columns are added to the heatmap tooltips
    "per_selection_tooltips": True,
    # Options used to build the scatter-plot sliders and, after defaults are
    # filled in, passed as keyword arguments to
    # `polyclonal.plot.lineplot_and_heatmap`.
    "plot_kwargs": {
        # statistic -> initial slider value
        "addtl_slider_stats": {
            "times_seen": 3,
            "difference_std": 2,
            "fraction_pairs_w_mutation": 1,
            "best_effect": -2,
            "220210 effect": None,
            "220302 effect": None,
            "nt changes to codon": 3,
        },
        "addtl_slider_stats_hide_not_filter": [
            "best_effect",
            "220210 effect",
            "220302 effect",
            "nt changes to codon",
        ],
        # statistics whose slider is an upper bound rather than a lower bound
        "addtl_slider_stats_as_max": ["difference_std", "nt changes to codon"],
        "heatmap_max_at_least": 1,
        "heatmap_min_at_least": -1,
        "init_floor_at_zero": False,
        "init_site_statistic": "mean_abs",
        "site_zoom_bar_color_col": "region",
        # per-statistic overrides for the slider range bindings
        "slider_binding_range_kwargs": {
            "times_seen": {"step": 1, "min": 1, "max": 25},
            "nt changes to codon": {"step": 1, "min": 1, "max": 3},
        },
    },
}
# input files
mutation_annotations_csv = "data/mutation_annotations.csv"
site_numbering_map_csv = "data/site_numbering_map.csv"
# output files
diffs_csv = "results/func_effect_diffs/220210_vs_220302_comparison_diffs.csv"
chart_html = "results/func_effect_diffs/220210_vs_220302_comparison_diffs.html"
corr_chart_html = (
    "results/func_effect_diffs/220210_vs_220302_comparison_diffs_corr.html"
)

Read the input data:

In [4]:
# Site numbering map: `reference_site` is renamed to `site`, the key used when
# merging with the functional-effect data.
site_numbering_map = pd.read_csv(site_numbering_map_csv)
site_numbering_map = site_numbering_map.rename(columns={"reference_site": "site"})
assert site_numbering_map[["site", "sequential_site"]].notnull().all().all()

# every other column whose name ends in "site" is kept as an additional site column
addtl_site_cols = [
    col
    for col in site_numbering_map.columns
    if col.endswith("site") and col != "site"
]

# mutation annotations are optional
if mutation_annotations_csv:
    mutation_annotations = pd.read_csv(mutation_annotations_csv)

# unpack the two conditions being compared
condition_1 = params["condition_1"]["name"]
condition_2 = params["condition_2"]["name"]
assert condition_1 != condition_2, f"{condition_1=}, {condition_2=}"

# each condition needs a non-empty list of unique selections, and no selection
# may belong to both conditions
condition_1_selections = params["condition_1"]["selections"]
condition_2_selections = params["condition_2"]["selections"]
assert len(condition_1_selections) == len(set(condition_1_selections))
assert len(condition_2_selections) == len(set(condition_2_selections))
assert len(condition_1_selections), params["condition_1"]
assert len(condition_2_selections), params["condition_2"]
if not set(condition_1_selections).isdisjoint(condition_2_selections):
    raise ValueError(
        f"shared selections in {condition_1_selections=} and {condition_2_selections=}"
    )

# Read the functional effects of every selection into one tidy data frame,
# labelling each row with its selection, its condition, and the mutation name.
func_effects = pd.concat(
    [
        pd.read_csv(f"results/func_effects/by_selection/{sel}_func_effects.csv").assign(
            selection=sel,
            condition=cond,
            times_seen=lambda df: df["times_seen"].astype("Int64"),
            mutation=lambda df: df["wildtype"] + df["site"].astype(str) + df["mutant"],
        )
        for cond, cond_selections in (
            (condition_1, condition_1_selections),
            (condition_2, condition_2_selections),
        )
        for sel in cond_selections
    ],
    ignore_index=True,
)

Correlations among all selections¶

Compute the correlations in the mutation effects across all selections:

In [5]:
# The initial `times_seen` threshold comes from the plot settings; fall back to
# 3 when it is not configured there.
try:
    init_times_seen = params["plot_kwargs"]["addtl_slider_stats"]["times_seen"]
except KeyError:
    print("No times seen in params, using a value of 3")
    init_times_seen = 3

# Thresholds at which correlations are computed. Building them as a set drops
# duplicates (e.g. an initial value of 1 would otherwise give [1, 1, 2]), which
# would stack the same rows twice under a single `min_times_seen` group.
corr_times_seen_thresholds = sorted({1, init_times_seen, 2 * init_times_seen})

# one filtered copy of the functional effects per threshold
func_effects_for_corr = pd.concat(
    [
        func_effects.query("times_seen >= @t", engine="python").assign(min_times_seen=t)
        for t in corr_times_seen_thresholds
    ]
)

# pairwise correlations between selections, computed separately per threshold
corrs = (
    dms_variants.utils.tidy_to_corr(
        df=func_effects_for_corr,
        sample_col="selection",
        label_col="mutation",
        value_col="functional_effect",
        group_cols=["min_times_seen"],
    )
    .assign(
        r2=lambda x: x["correlation"] ** 2,
        min_times_seen=lambda x: "min times seen " + x["min_times_seen"].astype(str),
    )
    .rename(columns={"correlation": "r"})
)

# heatmap of r2 between selections, one facet column per threshold
corr_chart = (
    alt.Chart(corrs)
    .encode(
        alt.X("selection_1", title=None),
        alt.Y("selection_2", title=None),
        column=alt.Column("min_times_seen", title=None),
        color=alt.Color("r2", scale=alt.Scale(zero=True)),
        tooltip=[
            alt.Tooltip(c, format=".3g") if c in {"r2", "r"} else c
            for c in ["selection_1", "selection_2", "r2", "r"]
        ],
    )
    .mark_rect(stroke="black")
    .properties(
        width=alt.Step(15),
        height=alt.Step(15),
        title="Per-selection correlation in mutation functional effects",
    )
    .configure_axis(labelLimit=500)
)

display(corr_chart)

print(
    f"\nSelections for {condition_1}: {condition_1_selections}\n"
    f"Selections for {condition_2}: {condition_2_selections}\n"
)
Selections for 220210: ['LibA-220210-293T_ACE2-1', 'LibA-220210-293T_ACE2-2']
Selections for 220302: ['LibA-220302-293T_ACE2-1', 'LibA-220302-293T_ACE2-2']

Average functional effects for each condition¶

Average the functional effects for each condition using the specified averaging method, then print the correlation between these average functional effects at several minimum times-seen thresholds:

In [6]:
avg_method = params["avg_method"]
assert avg_method in {"mean", "median"}, avg_method

# Average the effect of each mutation over the selections of each condition.
# `times_seen` becomes the summed times seen divided by the number of rows for
# that mutation, and is blanked for wildtype entries (mutant == wildtype).
avg_func_effects = (
    func_effects.groupby(
        ["condition", "site", "wildtype", "mutant", "mutation"], as_index=False
    )
    .aggregate(
        effect=pd.NamedAgg("functional_effect", avg_method),
        times_seen=pd.NamedAgg("times_seen", "sum"),
        n_selections=pd.NamedAgg("site", "count"),
    )
    .assign(
        times_seen=lambda x: (x["times_seen"] / x["n_selections"]).where(
            x["mutant"] != x["wildtype"],
            pd.NA,
        )
    )
)

# Thresholds at which the between-condition correlation is reported. A set
# drops duplicates (an initial value of 1 would otherwise give [1, 1, 2]),
# which would stack the same rows twice under a single `min_times_seen` group.
avg_corr_times_seen_thresholds = sorted({1, init_times_seen, 2 * init_times_seen})

avg_func_effects_for_corr = pd.concat(
    [
        avg_func_effects.query("times_seen >= @t", engine="python").assign(
            min_times_seen=t
        )
        for t in avg_corr_times_seen_thresholds
    ]
)
print("Correlation between average functional effects across conditions:")
display(
    dms_variants.utils.tidy_to_corr(
        df=avg_func_effects_for_corr,
        sample_col="condition",
        label_col="mutation",
        value_col="effect",
        group_cols=["min_times_seen"],
    )
    .assign(
        r2=lambda x: x["correlation"] ** 2,
        min_times_seen=lambda x: "min times seen " + x["min_times_seen"].astype(str),
    )
    .rename(columns={"correlation": "r"})
    # `condition_1` / `condition_2` here are columns of the correlation table;
    # keep a single off-diagonal row per threshold
    .query("condition_1 != condition_2")
    .reset_index(drop=True)
    .groupby("min_times_seen")
    .first()
    .round(3)
)
Correlation between average functional effects across conditions:
condition_1 condition_2 r r2
min_times_seen
min times seen 1 220302 220210 0.867 0.752
min times seen 3 220302 220210 0.879 0.773
min times seen 6 220302 220210 0.848 0.718

Compute pairwise differences¶

Compute pairwise differences in effects between all pairs of condition 1 selections versus condition 2 selections. For each comparison, we compute the times seen as the mean between the two selections being compared.

We then compute the average (using the specified average method) difference across comparisons, the mean times seen, and the fraction of comparisons in which a difference can be computed:

In [7]:
mutation_key_cols = ["wildtype", "site", "mutant"]
per_selection_cols = [*mutation_key_cols, "times_seen", "functional_effect"]

# Differences for every (condition 1 selection, condition 2 selection) pair:
# effect in the first minus effect in the second, with times seen averaged.
diffs_all = []
for sel1, sel2 in itertools.product(condition_1_selections, condition_2_selections):
    sel1_effects = func_effects.loc[
        func_effects["selection"] == sel1, per_selection_cols
    ]
    sel2_effects = func_effects.loc[
        func_effects["selection"] == sel2, per_selection_cols
    ]
    paired = sel1_effects.merge(sel2_effects, on=mutation_key_cols, validate="1:1")
    paired = paired.assign(
        times_seen=(paired["times_seen_x"] + paired["times_seen_y"]) / 2,
        difference=paired["functional_effect_x"] - paired["functional_effect_y"],
        comparison=f"{sel1} vs {sel2}",
    )
    diffs_all.append(
        paired[[*mutation_key_cols, "times_seen", "difference", "comparison"]]
    )

# Summarize each mutation across the pairwise comparisons.
n_possible_pairs = len(condition_1_selections) * len(condition_2_selections)
avg_diffs = (
    pd.concat(diffs_all, ignore_index=True)
    .groupby(mutation_key_cols, as_index=False)
    .aggregate(
        difference=pd.NamedAgg("difference", avg_method),
        difference_std=pd.NamedAgg("difference", "std"),
        times_seen=pd.NamedAgg("times_seen", "mean"),
        fraction_pairs_w_mutation=pd.NamedAgg(
            "difference", lambda s: len(s) / n_possible_pairs
        ),
    )
)

# Average effect in each condition (one column per condition) plus the larger
# of the two as `best_effect`.
condition_effects = avg_func_effects.pivot_table(
    index=["site", "wildtype", "mutant"],
    values="effect",
    columns="condition",
).reset_index()
condition_effects["best_effect"] = condition_effects[[condition_1, condition_2]].max(
    axis=1
)
condition_effects = condition_effects.rename(
    columns={
        condition_1: f"{condition_1} effect",
        condition_2: f"{condition_2} effect",
    }
)

# Per-selection text of the form "effect (times seen)"; the times-seen part is
# omitted for wildtype entries.
effect_text = func_effects["functional_effect"].map(lambda e: f"{e:.2f}")
times_seen_text = (" (" + func_effects["times_seen"].astype(str) + ")").where(
    func_effects["mutant"] != func_effects["wildtype"],
    "",
)
selection_effects = (
    func_effects.assign(effect_times_seen=effect_text + times_seen_text)
    .pivot_table(
        index=["site", "wildtype", "mutant"],
        values="effect_times_seen",
        columns="selection",
        aggfunc=lambda s: ",".join(s),
    )[condition_1_selections + condition_2_selections]
    .reset_index()
)

# Combine everything into the final table of differences, sorted by site and mutant.
diffs = (
    avg_diffs.merge(condition_effects, on=mutation_key_cols, validate="one_to_one")
    .merge(selection_effects, on=mutation_key_cols, validate="one_to_one")
    .sort_values(["site", "mutant"])
    .reset_index(drop=True)
)

print(f"Writing differences to {diffs_csv}")
diffs.to_csv(diffs_csv, index=False, float_format="%.4g")
Writing differences to results/func_effect_diffs/220210_vs_220302_comparison_diffs.csv

Make scatter plots correlating the per-mutation differences between each pair of comparisons, after applying the times seen filter:

In [8]:
print(f"Correlating differences for times_seen of {init_times_seen}")

# Stack the per-comparison differences, keep mutations passing the times-seen
# filter, and name each mutation. The comparison label becomes `selection`.
diffs_all_df = (
    pd.concat(diffs_all)
    .query("times_seen >= @init_times_seen")
    .assign(mutation=lambda x: x["wildtype"] + x["site"].astype(str) + x["mutant"])
    .rename(columns={"comparison": "selection"})
)

# One scatter panel per pair of comparisons, titled with the correlation and the
# number of mutations shared by the two comparisons.
comparison_names = sorted(diffs_all_df["selection"].unique())
corr_panels = []
for sel1, sel2 in itertools.combinations(comparison_names, 2):
    x_diffs = diffs_all_df.loc[
        diffs_all_df["selection"] == sel1, ["mutation", "difference"]
    ].rename(columns={"difference": sel1})
    y_diffs = diffs_all_df.loc[
        diffs_all_df["selection"] == sel2, ["mutation", "difference"]
    ].rename(columns={"difference": sel2})
    corr_df = x_diffs.merge(y_diffs, on="mutation", validate="one_to_one")

    n = len(corr_df)
    r = corr_df[[sel1, sel2]].corr().loc[sel2, sel1]

    panel_title = alt.TitleParams(
        f"R = {r:.2f}, N = {n}", fontSize=11, fontWeight="normal", dy=2
    )
    panel = (
        alt.Chart(corr_df)
        .encode(
            alt.X(sel1, scale=alt.Scale(nice=False, padding=4)),
            alt.Y(sel2, scale=alt.Scale(nice=False, padding=4)),
            tooltip=[
                "mutation",
                alt.Tooltip(sel1, format=".3f"),
                alt.Tooltip(sel2, format=".3f"),
            ],
        )
        .mark_circle(color="black", size=30, opacity=0.25)
        .properties(width=160, height=160, title=panel_title)
    )
    corr_panels.append(panel)

# lay the panels out in rows of at most `ncols` panels each
ncols = 4
corr_rows = [
    alt.hconcat(*corr_panels[start : start + ncols])
    for start in range(0, len(corr_panels), ncols)
]
alt.vconcat(*corr_rows).configure_axis(grid=False)
Correlating differences for times_seen of 3
Out[8]:

Make a scatter plot comparing the conditions¶

Make a correlation plot between the two conditions with informative tooltips and slider bars:

In [9]:
# highlight the mutation under the mouse
mutation_selection = alt.selection_point(
    on="mouseover", fields=["mutation"], empty=False
)

# Optionally add per-mutation annotations. The annotation table must be keyed by
# `site` and `mutant` and may not share any other column with `diffs`.
if mutation_annotations_csv:
    if not {"site", "mutant"}.issubset(mutation_annotations.columns):
        raise ValueError(f"{mutation_annotations.columns=} lacks 'site', 'mutant'")
    if set(mutation_annotations.columns).intersection(diffs.columns) != {
        "site",
        "mutant",
    }:
        raise ValueError(
            f"{mutation_annotations.columns=} shares columns with {diffs.columns=}"
        )
    diffs = diffs.merge(
        mutation_annotations,
        on=["site", "mutant"],
        how="left",
        validate="many_to_one",
    )
    # annotations are blanked for wildtype entries
    for col in mutation_annotations.columns:
        if col not in {"site", "mutant"}:
            diffs[col] = diffs[col].where(diffs["wildtype"] != diffs["mutant"], pd.NA)

# Data for the scatter plot: mutations only (no wildtype entries), identified by
# a single `mutation` column placed first.
corr_diffs = (
    diffs.query("wildtype != mutant")
    .assign(
        mutation=lambda x: x["wildtype"] + x["site"].astype(str) + x["mutant"],
    )
    .drop(columns=["wildtype", "site", "mutant"])
)
corr_diffs = corr_diffs[
    ["mutation"] + [c for c in corr_diffs.columns if c != "mutation"]
]

# Deep copy: later code edits nested dicts / lists of `plot_kwargs` in place, and
# a shallow copy would let those edits leak back into `params`.
plot_kwargs = copy.deepcopy(params["plot_kwargs"])
if "slider_binding_range_kwargs" not in plot_kwargs:
    plot_kwargs["slider_binding_range_kwargs"] = {}

# Settings used only to build the scatter-plot sliders. They are read with
# defaults rather than inserted into `plot_kwargs`, so that later code can still
# tell whether `addtl_slider_stats` / `addtl_slider_stats_as_max` were specified.
scatter_slider_stats = plot_kwargs.get("addtl_slider_stats", {})
scatter_stats_as_max = plot_kwargs.get("addtl_slider_stats_as_max", [])
scatter_range_kwargs = plot_kwargs["slider_binding_range_kwargs"]

sliders = {}
for stat, init_value in scatter_slider_stats.items():
    is_max_slider = stat in scatter_stats_as_max
    if init_value is None:
        # No initial value configured: start at the extreme that filters nothing
        # (largest value for a maximum slider, smallest for a minimum slider).
        # NOTE(review): assumes None means "no initial restriction" -- confirm
        # this matches how `polyclonal.plot.lineplot_and_heatmap` treats None.
        init_value = (
            corr_diffs[stat].max() if is_max_slider else corr_diffs[stat].min()
        )
    # slider spans the observed range unless overridden for this statistic
    binding_kwargs = {
        "name": f"maximum {stat}" if is_max_slider else f"minimum {stat}",
        "min": corr_diffs[stat].min(),
        "max": corr_diffs[stat].max(),
    } | scatter_range_kwargs.get(stat, {})
    sliders[stat] = alt.param(
        value=init_value,
        bind=alt.binding_range(**binding_kwargs),
    )

corr_chart = (
    alt.Chart(corr_diffs)
    .add_params(mutation_selection)
    .encode(
        alt.X(
            f"{condition_1} effect", scale=alt.Scale(nice=False, zero=False, padding=5)
        ),
        alt.Y(
            f"{condition_2} effect", scale=alt.Scale(nice=False, zero=False, padding=5)
        ),
        strokeWidth=alt.condition(mutation_selection, alt.value(2), alt.value(0)),
        size=alt.condition(mutation_selection, alt.value(70), alt.value(45)),
        tooltip=[
            alt.Tooltip(c, format=".3g") if corr_diffs[c].dtype == float else c
            for c in corr_diffs.columns
        ],
    )
    .mark_circle(fill="black", fillOpacity=0.35, stroke="red")
    .properties(width=275, height=275)
    .configure_axis(grid=False)
)

# each slider filters the points: maximum sliders keep values at or below the
# slider, minimum sliders keep values at or above it
for stat, slider in sliders.items():
    if stat in scatter_stats_as_max:
        corr_chart = corr_chart.add_params(slider).transform_filter(
            alt.datum[stat] <= slider
        )
    else:
        corr_chart = corr_chart.add_params(slider).transform_filter(
            alt.datum[stat] >= slider
        )

print(f"Saving to {corr_chart_html=}")
corr_chart.save(corr_chart_html)

corr_chart
Saving to corr_chart_html='results/func_effect_diffs/220210_vs_220302_comparison_diffs_corr.html'
Out[9]:

Make interactive chart¶

Set up keyword arguments to https://jbloomlab.github.io/polyclonal/polyclonal.plot.html#polyclonal.plot.lineplot_and_heatmap if they are not already specified:

In [10]:
# --- sliders -----------------------------------------------------------------
slider_stats = plot_kwargs.setdefault("addtl_slider_stats", {})
slider_stats.setdefault("times_seen", 3)

if "difference_std" in slider_stats:
    # the user configured the slider, so they must also say which sliders are maxima
    if "addtl_slider_stats_as_max" not in plot_kwargs:
        raise ValueError(
            "You specified `difference_std` in `addtl_slider_stats` but did not add it to "
            "`addtl_slider_stats_as_max`. If you really do not want `difference_std` in "
            "`addtl_slider_stats_as_max`, then specify that list without it."
        )
else:
    # default: a maximum slider starting at the largest observed value
    slider_stats["difference_std"] = diffs["difference_std"].max()
    plot_kwargs.setdefault("addtl_slider_stats_as_max", []).append("difference_std")

slider_stats.setdefault("fraction_pairs_w_mutation", 0.5)

# --- site zoom bar color -----------------------------------------------------
# pull the coloring column from the site numbering map if `diffs` lacks it
if "site_zoom_bar_color_col" in plot_kwargs:
    zoom_color_col = plot_kwargs["site_zoom_bar_color_col"]
    if (
        zoom_color_col not in diffs.columns
        and zoom_color_col in site_numbering_map.columns
    ):
        diffs = diffs.merge(
            site_numbering_map[["site", zoom_color_col]],
            on="site",
            validate="many_to_one",
            how="left",
        )

# --- tooltips ----------------------------------------------------------------
tooltip_stats = plot_kwargs.setdefault("addtl_tooltip_stats", [])
for tooltip_col in ["difference_std", *addtl_site_cols]:
    if tooltip_col not in tooltip_stats:
        tooltip_stats.append(tooltip_col)

# make sure the additional site columns are available in `diffs`
if "sequential_site" not in diffs.columns:
    diffs = diffs.merge(
        site_numbering_map[["site", *addtl_site_cols]],
        on="site",
        validate="many_to_one",
        how="left",
    )
# only show sequential numbering when it differs from the site numbering
if any(diffs["site"] != diffs["sequential_site"]):
    if "sequential_site" not in tooltip_stats:
        tooltip_stats.append("sequential_site")

if params["per_selection_tooltips"]:
    all_selections = condition_1_selections + condition_2_selections
    assert set(all_selections).issubset(diffs.columns)
    missing_selections = [s for s in all_selections if s not in tooltip_stats]
    tooltip_stats.extend(missing_selections)

# --- alphabet and sites ------------------------------------------------------
if "alphabet" not in plot_kwargs:
    observed_mutants = set(diffs["mutant"])
    plot_kwargs["alphabet"] = [
        a
        for a in polyclonal.alphabets.biochem_order_aas(polyclonal.AAS_WITHSTOP_WITHGAP)
        if a in observed_mutants
    ]

if "sites" not in plot_kwargs:
    sites_in_order = site_numbering_map.sort_values("sequential_site")["site"]
    plot_kwargs["sites"] = sites_in_order.tolist()

Now make the interactive heatmap:

In [11]:
# A constant `_dummy` column is supplied as the category column, so every row
# falls in the same category; make sure the name does not clash first.
assert "_dummy" not in diffs.columns
heatmap_data = diffs.assign(_dummy="dummy")

chart = polyclonal.plot.lineplot_and_heatmap(
    data_df=heatmap_data,
    stat_col="difference",
    category_col="_dummy",
    **plot_kwargs,
)

display(chart)

# write the chart to a standalone HTML file
print(f"\nSaving to {chart_html}")
chart.save(chart_html)
Saving to results/func_effect_diffs/220210_vs_220302_comparison_diffs.html
In [ ]: